{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Adding support for a new models\n",
    "\n",
    "Different models are tuned with different role prompt formats. If the model you are using is not already a subclass of `guidance.Model`, you can define your own new subclass with whatever role prompt format you want. Then you can use the guidance role tags and they will get translated into the correct prompt format."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Orca Mini Chat Example\n",
    "\n",
    "As an example of how this works, below we implement the [Orca Mini 3B](https://huggingface.co/pankajmathur/orca_mini_3b) prompt. The prompt looks like this:\n",
    "\n",
    "`### System:\\n{system}\\n\\n### User:\\n{instruction}\\n\\n### Response:\\n\"`\n",
    "\n",
    "Where `system` is the system prompt and `instruction` is the user input instructions."
   ]
  },
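  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the template concrete, the following is a minimal sketch that fills it in with plain Python string formatting. The helper name `format_orca_prompt` is hypothetical and not part of guidance; it only illustrates the raw text that the role tags should produce for one system and user turn (the `OrcaChat` class defined below additionally emits a single space when the assistant role opens)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical helper, not part of guidance: fills in the Orca Mini template shown above.\n",
    "def format_orca_prompt(system: str, instruction: str) -> str:\n",
    "    \"\"\"Return the raw prompt string for one system + user turn.\"\"\"\n",
    "    return f\"### System:\\n{system}\\n\\n### User:\\n{instruction}\\n\\n### Response:\\n\"\n",
    "\n",
    "# repr() makes the newlines visible, which helps when comparing against the model's raw context\n",
    "print(repr(format_orca_prompt(\"You are a cat expert.\", \"What are the smallest cats?\")))"
   ]
  },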
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define a new OrcaChat class"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from guidance import models\n",
    "\n",
    "# this is our new chat model that inherits from the TransformersChat model and redefines role starts and ends\n",
    "class OrcaChat(models.transformers.TransformersChat):\n",
    "    def get_role_start(self, role_name, **kwargs):\n",
    "        if role_name == \"system\":\n",
    "            return \"### System:\\n\"\n",
    "        elif role_name == \"user\":\n",
    "            if str(self).endswith(\"\\n\\n### User:\\n\"):\n",
    "                return \"\" # we don't need to start anything if we are starting with a top level unnested system tag\n",
    "            else:\n",
    "                return \"### System:\\n\"\n",
    "        else:\n",
    "            return \" \"\n",
    "\n",
    "    def get_role_end(self, role_name=None):\n",
    "        if role_name == \"system\":\n",
    "            return \"\\n\\n### User:\\n\"\n",
    "        elif role_name == \"user\":\n",
    "            return \"\\n\\n### Response:\\n\"\n",
    "        else:\n",
    "            return \" \""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thouroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3e2eb7c444ba4d92a5f29593faa919e9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "orca = OrcaChat('pankajmathur/orca_mini_3b', torch_dtype=torch.float16, device_map='auto')\n",
    "# orca = OrcaChat('gpt2', device_map='auto') # Can use a small mock model while iterating on the prompt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Check output without indentation formatting\n",
    "\n",
    "When you designing the prompt format you will want to see exactly how roles are going into the mode without any pretty indentation. Setting `indent=False` allows you to do this (and so is useful during role format development)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<pre style='margin: 0px; padding: 0px; vertical-align: middle; padding-left: 8px; margin-left: -8px; border-radius: 0px; border-left: 1px solid rgba(127, 127, 127, 0.2); white-space: pre-wrap; font-family: ColfaxAI, Arial; font-size: 15px; line-height: 23px;'><span style='background-color: rgba(255, 180, 0, 0.3); border-radius: 3px;'>### System:\n",
       "</span>You are a cat expert.<span style='background-color: rgba(255, 180, 0, 0.3); border-radius: 3px;'>\n",
       "\n",
       "### User:\n",
       "</span><span style='background-color: rgba(255, 180, 0, 0.3); border-radius: 3px;'></span>What are the smallest cats?<span style='background-color: rgba(255, 180, 0, 0.3); border-radius: 3px;'>\n",
       "\n",
       "### Response:\n",
       "</span><span style='background-color: rgba(255, 180, 0, 0.3); border-radius: 3px;'> </span><span style='background-color: rgba(0.0, 165.0, 0, 0.15); border-radius: 3px;' title='1.0'>The</span><span style='background-color: rgba(0.0, 165.0, 0, 0.15); border-radius: 3px;' title='1.0'> smallest</span><span style='background-color: rgba(0.0, 165.0, 0, 0.15); border-radius: 3px;' title='1.0'> cats</span><span style='background-color: rgba(0.0, 165.0, 0, 0.15); border-radius: 3px;' title='1.0'> are</span><span style='background-color: rgba(0.0, 165.0, 0, 0.15); border-radius: 3px;' title='1.0'> the</span><span style='background-color: rgba(0.0, 165.0, 0, 0.15); border-radius: 3px;' title='1.0'> dwar</span><span style='background-color: rgba(0.0, 165.0, 0, 0.15); border-radius: 3px;' title='1.0'>f</span><span style='background-color: rgba(0.0, 165.0, 0, 0.15); border-radius: 3px;' title='1.0'> cats</span><span style='background-color: rgba(0.0, 165.0, 0, 0.15); border-radius: 3px;' title='1.0'>,</span><span style='background-color: rgba(0.0, 165.0, 0, 0.15); border-radius: 3px;' title='1.0'> which</span><span style='background-color: rgba(0.0, 165.0, 0, 0.15); border-radius: 3px;' title='1.0'> include</span><span style='background-color: rgba(0.0, 165.0, 0, 0.15); border-radius: 3px;' title='1.0'> the</span><span style='background-color: rgba(0.0, 165.0, 0, 0.15); border-radius: 3px;' title='1.0'> following</span><span style='background-color: rgba(0.0, 165.0, 0, 0.15); border-radius: 3px;' title='1.0'> bre</span><span style='background-color: rgba(0.0, 165.0, 0, 0.15); border-radius: 3px;' title='1.0'>eds</span><span style='background-color: rgba(0.0, 165.0, 0, 0.15); border-radius: 3px;' title='1.0'>:</span><span style='background-color: rgba(0.0, 165.0, 0, 0.15); border-radius: 3px;' title='1.0'>\n",
       "</span><span style='background-color: rgba(0.0, 165.0, 0, 0.15); border-radius: 3px;' title='1.0'>\n",
       "</span><span style='background-color: rgba(0.0, 165.0, 0, 0.15); border-radius: 3px;' title='1.0'>1</span><span style='background-color: rgba(255, 180, 0, 0.3); border-radius: 3px;'> </span></pre>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from guidance import gen, system, user, assistant, indent_roles\n",
    "\n",
    "with indent_roles(False):\n",
    "    with system():\n",
    "        lm = orca + \"You are a cat expert.\"\n",
    "\n",
    "    with user():\n",
    "        lm += \"What are the smallest cats?\"\n",
    "\n",
    "    with assistant():\n",
    "        lm += gen(\"answer\", stop=\".\", max_tokens=20)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that you can also just print the string representation of the model for a truely raw view of the model's context:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "### System:\n",
      "You are a cat expert.\n",
      "\n",
      "### User:\n",
      "What are the smallest cats?\n",
      "\n",
      "### Response:\n",
      " The smallest cats are the dwarf cats, which include the following breeds:\n",
      "\n",
      "1 \n"
     ]
    }
   ],
   "source": [
    "print(str(lm))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can also change just a single role tag:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<pre style='margin: 0px; padding: 0px; vertical-align: middle; padding-left: 8px; margin-left: -8px; border-radius: 0px; border-left: 1px solid rgba(127, 127, 127, 0.2); white-space: pre-wrap; font-family: ColfaxAI, Arial; font-size: 15px; line-height: 23px;'><span style='background-color: rgba(255, 180, 0, 0.3); border-radius: 3px;'>### System:\n",
       "</span>You are a cat expert.<span style='background-color: rgba(255, 180, 0, 0.3); border-radius: 3px;'>\n",
       "\n",
       "### User:\n",
       "</span><div style='display: flex; border-bottom: 1px solid rgba(127, 127, 127, 0.2);  justify-content: center; align-items: center;'><div style='flex: 0 0 80px; opacity: 0.5;'>user</div><div style='flex-grow: 1; padding: 5px; padding-top: 10px; padding-bottom: 10px; margin-top: 0px; white-space: pre-wrap; margin-bottom: 0px;'>What are the smallest cats?</div></div><div style='display: flex; border-bottom: 1px solid rgba(127, 127, 127, 0.2);  justify-content: center; align-items: center;'><div style='flex: 0 0 80px; opacity: 0.5;'>assistant</div><div style='flex-grow: 1; padding: 5px; padding-top: 10px; padding-bottom: 10px; margin-top: 0px; white-space: pre-wrap; margin-bottom: 0px;'><span style='background-color: rgba(0.0, 165.0, 0, 0.15); border-radius: 3px;' title='1.0'>The</span><span style='background-color: rgba(0.0, 165.0, 0, 0.15); border-radius: 3px;' title='1.0'> smallest</span><span style='background-color: rgba(0.0, 165.0, 0, 0.15); border-radius: 3px;' title='1.0'> cats</span><span style='background-color: rgba(0.0, 165.0, 0, 0.15); border-radius: 3px;' title='1.0'> are</span><span style='background-color: rgba(0.0, 165.0, 0, 0.15); border-radius: 3px;' title='1.0'> the</span><span style='background-color: rgba(0.0, 165.0, 0, 0.15); border-radius: 3px;' title='1.0'> dwar</span><span style='background-color: rgba(0.0, 165.0, 0, 0.15); border-radius: 3px;' title='1.0'>f</span><span style='background-color: rgba(0.0, 165.0, 0, 0.15); border-radius: 3px;' title='1.0'> cats</span><span style='background-color: rgba(0.0, 165.0, 0, 0.15); border-radius: 3px;' title='1.0'>,</span><span style='background-color: rgba(0.0, 165.0, 0, 0.15); border-radius: 3px;' title='1.0'> which</span><span style='background-color: rgba(0.0, 165.0, 0, 0.15); border-radius: 3px;' title='1.0'> include</span><span style='background-color: rgba(0.0, 165.0, 0, 0.15); border-radius: 3px;' title='1.0'> the</span><span style='background-color: rgba(0.0, 165.0, 0, 0.15); 
border-radius: 3px;' title='1.0'> following</span><span style='background-color: rgba(0.0, 165.0, 0, 0.15); border-radius: 3px;' title='1.0'> bre</span><span style='background-color: rgba(0.0, 165.0, 0, 0.15); border-radius: 3px;' title='1.0'>eds</span><span style='background-color: rgba(0.0, 165.0, 0, 0.15); border-radius: 3px;' title='1.0'>:</span><span style='background-color: rgba(0.0, 165.0, 0, 0.15); border-radius: 3px;' title='1.0'>\n",
       "</span><span style='background-color: rgba(0.0, 165.0, 0, 0.15); border-radius: 3px;' title='1.0'>\n",
       "</span><span style='background-color: rgba(0.0, 165.0, 0, 0.15); border-radius: 3px;' title='1.0'>1</span></div></div></pre>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "with indent_roles(False), system():\n",
    "    lm = orca + \"You are a cat expert.\"\n",
    "\n",
    "with user():\n",
    "    lm += \"What are the smallest cats?\"\n",
    "\n",
    "with assistant():\n",
    "    lm += gen(\"answer\", stop=\".\", max_tokens=20)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Normal use\n",
    "\n",
    "When you are satisfied with the correctness of your role formatting you can then use the model like normal:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<pre style='margin: 0px; padding: 0px; vertical-align: middle; padding-left: 8px; margin-left: -8px; border-radius: 0px; border-left: 1px solid rgba(127, 127, 127, 0.2); white-space: pre-wrap; font-family: ColfaxAI, Arial; font-size: 15px; line-height: 23px;'><div style='display: flex; border-bottom: 1px solid rgba(127, 127, 127, 0.2);  justify-content: center; align-items: center;'><div style='flex: 0 0 80px; opacity: 0.5;'>system</div><div style='flex-grow: 1; padding: 5px; padding-top: 10px; padding-bottom: 10px; margin-top: 0px; white-space: pre-wrap; margin-bottom: 0px;'>You are a cat expert.</div></div><div style='display: flex; border-bottom: 1px solid rgba(127, 127, 127, 0.2);  justify-content: center; align-items: center;'><div style='flex: 0 0 80px; opacity: 0.5;'>user</div><div style='flex-grow: 1; padding: 5px; padding-top: 10px; padding-bottom: 10px; margin-top: 0px; white-space: pre-wrap; margin-bottom: 0px;'>What are the smallest cats?</div></div><div style='display: flex; border-bottom: 1px solid rgba(127, 127, 127, 0.2);  justify-content: center; align-items: center;'><div style='flex: 0 0 80px; opacity: 0.5;'>assistant</div><div style='flex-grow: 1; padding: 5px; padding-top: 10px; padding-bottom: 10px; margin-top: 0px; white-space: pre-wrap; margin-bottom: 0px;'><span style='background-color: rgba(0.0, 165.0, 0, 0.15); border-radius: 3px;' title='1.0'>The</span><span style='background-color: rgba(0.0, 165.0, 0, 0.15); border-radius: 3px;' title='1.0'> smallest</span><span style='background-color: rgba(0.0, 165.0, 0, 0.15); border-radius: 3px;' title='1.0'> cats</span><span style='background-color: rgba(0.0, 165.0, 0, 0.15); border-radius: 3px;' title='1.0'> are</span><span style='background-color: rgba(0.0, 165.0, 0, 0.15); border-radius: 3px;' title='1.0'> the</span><span style='background-color: rgba(0.0, 165.0, 0, 0.15); border-radius: 3px;' title='1.0'> dwar</span><span style='background-color: rgba(0.0, 165.0, 0, 0.15); 
border-radius: 3px;' title='1.0'>f</span><span style='background-color: rgba(0.0, 165.0, 0, 0.15); border-radius: 3px;' title='1.0'> cats</span><span style='background-color: rgba(0.0, 165.0, 0, 0.15); border-radius: 3px;' title='1.0'>,</span><span style='background-color: rgba(0.0, 165.0, 0, 0.15); border-radius: 3px;' title='1.0'> which</span><span style='background-color: rgba(0.0, 165.0, 0, 0.15); border-radius: 3px;' title='1.0'> include</span><span style='background-color: rgba(0.0, 165.0, 0, 0.15); border-radius: 3px;' title='1.0'> the</span><span style='background-color: rgba(0.0, 165.0, 0, 0.15); border-radius: 3px;' title='1.0'> following</span><span style='background-color: rgba(0.0, 165.0, 0, 0.15); border-radius: 3px;' title='1.0'> bre</span><span style='background-color: rgba(0.0, 165.0, 0, 0.15); border-radius: 3px;' title='1.0'>eds</span><span style='background-color: rgba(0.0, 165.0, 0, 0.15); border-radius: 3px;' title='1.0'>:</span><span style='background-color: rgba(0.0, 165.0, 0, 0.15); border-radius: 3px;' title='1.0'>\n",
       "</span><span style='background-color: rgba(0.0, 165.0, 0, 0.15); border-radius: 3px;' title='1.0'>\n",
       "</span><span style='background-color: rgba(0.0, 165.0, 0, 0.15); border-radius: 3px;' title='1.0'>1</span></div></div></pre>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "with system():\n",
    "    lm = orca + \"You are a cat expert.\"\n",
    "\n",
    "with user():\n",
    "    lm += \"What are the smallest cats?\"\n",
    "\n",
    "with assistant():\n",
    "    lm += gen(\"answer\", stop=\".\", max_tokens=20)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<hr style=\"height: 1px; opacity: 0.5; border: none; background: #cccccc;\">\n",
    "<div style=\"text-align: center; opacity: 0.5\">Have an idea for more helpful examples? Pull requests that add to this documentation notebook are encouraged!</div>"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
